{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "310571f0",
   "metadata": {},
   "outputs": [],
   "source": [
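    "# Execute this cell to install dependencies\n",
    "# NOTE: this notebook also imports `ray`; if it is not already installed, run `%pip install ray` as well.\n",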
    "%pip install sf-hamilton[visualization]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97065bbc",
   "metadata": {},
   "source": [
    "# Ray + Hamilton tutorial [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/apache/hamilton/blob/main/examples/ray/hello_world/notebook.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/apache/hamilton/blob/main/examples/ray/hello_world/notebook.ipynb)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2596c244",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-18T23:28:04.483891Z",
     "start_time": "2023-09-18T23:28:04.467608Z"
    },
    "collapsed": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "# We use the autoreload extension that comes with ipython to automatically reload modules when\n",
    "# the code in them changes.\n",
    "\n",
    "# load the IPython autoreload extension\n",
    "%load_ext autoreload\n",
    "# mode 1: only reload modules that were registered with %aimport\n",
    "%autoreload 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "39bf27b9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-18T23:28:04.484177Z",
     "start_time": "2023-09-18T23:28:04.474636Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import ray\n",
    "import pandas as pd\n",
    "\n",
    "from hamilton import base, driver\n",
    "from hamilton.plugins import h_ray"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c8f46d34",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-18T23:28:04.505831Z",
     "start_time": "2023-09-18T23:28:04.477897Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting spend_calculations.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile spend_calculations.py\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "def spend(spend_location: str) -> pd.Series:\n",
    "    \"\"\"Dummy function showing how to wire through loading data.\n",
    "\n",
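    "    :param spend_location: location of the spend data; unused here because the values are hard-coded.\n",
    "    :return: the spend values as a pandas Series.\n",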
    "    \"\"\"\n",
    "    return pd.Series([10, 10, 20, 40, 40, 50])\n",
    "\n",
    "\n",
    "def signups(signups_location: str) -> pd.Series:\n",
    "    \"\"\"Dummy function showing how to wire through loading data.\n",
    "\n",
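    "    :param signups_location: location of the signups data; unused here because the values are hard-coded.\n",
    "    :return: the signup counts as a pandas Series.\n",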
    "    \"\"\"\n",
    "    return pd.Series([1, 10, 50, 100, 200, 400])\n",
    "\n",
    "\n",
    "def avg_3wk_spend(spend: pd.Series) -> pd.Series:\n",
    "    \"\"\"Rolling 3 week average spend.\"\"\"\n",
    "    return spend.rolling(3).mean()\n",
    "\n",
    "\n",
    "def spend_per_signup(spend: pd.Series, signups: pd.Series) -> pd.Series:\n",
    "    \"\"\"The cost per signup in relation to spend.\"\"\"\n",
    "    return spend / signups\n",
    "\n",
    "\n",
    "def spend_mean(spend: pd.Series) -> float:\n",
    "    \"\"\"Shows function creating a scalar. In this case it computes the mean of the entire column.\"\"\"\n",
    "    return spend.mean()\n",
    "\n",
    "\n",
    "def spend_zero_mean(spend: pd.Series, spend_mean: float) -> pd.Series:\n",
    "    \"\"\"Shows function that takes a scalar. In this case to zero mean spend.\"\"\"\n",
    "    return spend - spend_mean\n",
    "\n",
    "\n",
    "def spend_std_dev(spend: pd.Series) -> float:\n",
    "    \"\"\"Function that computes the standard deviation of the spend column.\"\"\"\n",
    "    return spend.std()\n",
    "\n",
    "\n",
    "def spend_zero_mean_unit_variance(spend_zero_mean: pd.Series, spend_std_dev: float) -> pd.Series:\n",
    "    \"\"\"Function showing one way to make spend have zero mean and unit variance.\"\"\"\n",
    "    return spend_zero_mean / spend_std_dev"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "3cd824c9",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-18T23:28:04.507804Z",
     "start_time": "2023-09-18T23:28:04.489070Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "%aimport spend_calculations\n",
    "\n",
    "# Set up the driver, input and output columns\n",
    "\n",
    "config = {  # could load data here via some other means, or delegate to a module as we have done.\n",
    "        # 'signups': pd.Series([1, 10, 50, 100, 200, 400]),\n",
    "        \"signups_location\": \"some_path\",\n",
    "        # 'spend': pd.Series([10, 10, 20, 40, 40, 50]),\n",
    "        \"spend_location\": \"some_other_path\",\n",
    "    }\n",
    "adapter = h_ray.RayGraphAdapter(result_builder=base.PandasDataFrameResult())\n",
    "dr = driver.Driver(config, spend_calculations, adapter=adapter)\n",
    "output_columns = [\n",
    "    \"spend\",\n",
    "    \"signups\",\n",
    "    \"avg_3wk_spend\",\n",
    "    \"spend_per_signup\",\n",
    "    \"spend_zero_mean_unit_variance\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "ee0d8a50",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-18T23:28:09.929554Z",
     "start_time": "2023-09-18T23:28:04.518243Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-09-18 19:28:07,329\tINFO worker.py:1621 -- Started a local Ray instance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   spend  signups  avg_3wk_spend  spend_per_signup  \\\n",
      "0     10        1            NaN            10.000   \n",
      "1     10       10            NaN             1.000   \n",
      "2     20       50      13.333333             0.400   \n",
      "3     40      100      23.333333             0.400   \n",
      "4     40      200      33.333333             0.200   \n",
      "5     50      400      43.333333             0.125   \n",
      "\n",
      "   spend_zero_mean_unit_variance  \n",
      "0                      -1.064405  \n",
      "1                      -1.064405  \n",
      "2                      -0.483821  \n",
      "3                       0.677349  \n",
      "4                       0.677349  \n",
      "5                       1.257934  \n"
     ]
    }
   ],
   "source": [
    "# Execute the driver.\n",
    "df = dr.execute(output_columns)\n",
    "print(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e991d1f4",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-18T23:28:10.134103Z",
     "start_time": "2023-09-18T23:28:09.930093Z"
    },
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 8.1.0 (20230707.0739)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"588pt\" height=\"332pt\"\n",
       " viewBox=\"0.00 0.00 587.61 332.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 328)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-328 583.61,-328 583.61,4 -4,4\"/>\n",
       "<!-- spend_zero_mean_unit_variance -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>spend_zero_mean_unit_variance</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"530.18,-36 335.68,-36 335.68,0 530.18,0 530.18,-36\"/>\n",
       "<text text-anchor=\"middle\" x=\"432.93\" y=\"-12.95\" font-family=\"Times,serif\" font-size=\"14.00\">spend_zero_mean_unit_variance</text>\n",
       "</g>\n",
       "<!-- spend_per_signup -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>spend_per_signup</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"155.05,-180 40.8,-180 40.8,-144 155.05,-144 155.05,-180\"/>\n",
       "<text text-anchor=\"middle\" x=\"97.93\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">spend_per_signup</text>\n",
       "</g>\n",
       "<!-- spend_location -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>spend_location</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"351.93\" cy=\"-306\" rx=\"91.27\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"351.93\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: spend_location</text>\n",
       "</g>\n",
       "<!-- spend -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>spend</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"378.93,-252 324.93,-252 324.93,-216 378.93,-216 378.93,-252\"/>\n",
       "<text text-anchor=\"middle\" x=\"351.93\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">spend</text>\n",
       "</g>\n",
       "<!-- spend_location&#45;&gt;spend -->\n",
       "<g id=\"edge7\" class=\"edge\">\n",
       "<title>spend_location&#45;&gt;spend</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M351.93,-287.7C351.93,-280.24 351.93,-271.32 351.93,-262.97\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"355.43,-263.1 351.93,-253.1 348.43,-263.1 355.43,-263.1\"/>\n",
       "</g>\n",
       "<!-- avg_3wk_spend -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>avg_3wk_spend</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"276.8,-180 173.05,-180 173.05,-144 276.8,-144 276.8,-180\"/>\n",
       "<text text-anchor=\"middle\" x=\"224.93\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">avg_3wk_spend</text>\n",
       "</g>\n",
       "<!-- spend_mean -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>spend_mean</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"351.93\" cy=\"-162\" rx=\"57.49\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"351.93\" y=\"-156.95\" font-family=\"Times,serif\" font-size=\"14.00\">spend_mean</text>\n",
       "</g>\n",
       "<!-- spend_zero_mean -->\n",
       "<g id=\"node9\" class=\"node\">\n",
       "<title>spend_zero_mean</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"351.93\" cy=\"-90\" rx=\"77.97\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"351.93\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">spend_zero_mean</text>\n",
       "</g>\n",
       "<!-- spend_mean&#45;&gt;spend_zero_mean -->\n",
       "<g id=\"edge11\" class=\"edge\">\n",
       "<title>spend_mean&#45;&gt;spend_zero_mean</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M351.93,-143.7C351.93,-136.24 351.93,-127.32 351.93,-118.97\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"355.43,-119.1 351.93,-109.1 348.43,-119.1 355.43,-119.1\"/>\n",
       "</g>\n",
       "<!-- spend&#45;&gt;spend_per_signup -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>spend&#45;&gt;spend_per_signup</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M324.74,-225.51C287.15,-215.15 218.03,-196.1 165.93,-181.74\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"166.99,-178.13 156.42,-178.85 165.13,-184.88 166.99,-178.13\"/>\n",
       "</g>\n",
       "<!-- spend&#45;&gt;avg_3wk_spend -->\n",
       "<g id=\"edge5\" class=\"edge\">\n",
       "<title>spend&#45;&gt;avg_3wk_spend</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M324.68,-217.98C307.6,-208.57 285.25,-196.25 266.04,-185.66\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"268.05,-182.22 257.61,-180.46 264.67,-188.35 268.05,-182.22\"/>\n",
       "</g>\n",
       "<!-- spend&#45;&gt;spend_mean -->\n",
       "<g id=\"edge6\" class=\"edge\">\n",
       "<title>spend&#45;&gt;spend_mean</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M351.93,-215.7C351.93,-208.24 351.93,-199.32 351.93,-190.97\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"355.43,-191.1 351.93,-181.1 348.43,-191.1 355.43,-191.1\"/>\n",
       "</g>\n",
       "<!-- spend_std_dev -->\n",
       "<g id=\"node8\" class=\"node\">\n",
       "<title>spend_std_dev</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" cx=\"513.93\" cy=\"-90\" rx=\"65.68\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"513.93\" y=\"-84.95\" font-family=\"Times,serif\" font-size=\"14.00\">spend_std_dev</text>\n",
       "</g>\n",
       "<!-- spend&#45;&gt;spend_std_dev -->\n",
       "<g id=\"edge9\" class=\"edge\">\n",
       "<title>spend&#45;&gt;spend_std_dev</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M379.4,-218.05C396.23,-208.28 417.76,-194.65 434.93,-180 457.34,-160.87 479.31,-135.42 494.36,-116.64\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"497.72,-119.03 501.17,-109.01 492.23,-114.69 497.72,-119.03\"/>\n",
       "</g>\n",
       "<!-- spend&#45;&gt;spend_zero_mean -->\n",
       "<g id=\"edge10\" class=\"edge\">\n",
       "<title>spend&#45;&gt;spend_zero_mean</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M379.17,-219.01C393.55,-210 409.86,-196.83 417.93,-180 424.84,-165.57 424.84,-158.43 417.93,-144 412.07,-131.78 401.87,-121.49 391.22,-113.33\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"393.43,-109.9 383.25,-106.97 389.38,-115.61 393.43,-109.9\"/>\n",
       "</g>\n",
       "<!-- signups -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>signups</title>\n",
       "<polygon fill=\"none\" stroke=\"black\" points=\"126.55,-252 69.3,-252 69.3,-216 126.55,-216 126.55,-252\"/>\n",
       "<text text-anchor=\"middle\" x=\"97.93\" y=\"-228.95\" font-family=\"Times,serif\" font-size=\"14.00\">signups</text>\n",
       "</g>\n",
       "<!-- signups&#45;&gt;spend_per_signup -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>signups&#45;&gt;spend_per_signup</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M97.93,-215.7C97.93,-208.24 97.93,-199.32 97.93,-190.97\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"101.43,-191.1 97.93,-181.1 94.43,-191.1 101.43,-191.1\"/>\n",
       "</g>\n",
       "<!-- spend_std_dev&#45;&gt;spend_zero_mean_unit_variance -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>spend_std_dev&#45;&gt;spend_zero_mean_unit_variance</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M494.73,-72.41C484.74,-63.78 472.33,-53.05 461.26,-43.48\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"463.97,-40.34 454.11,-36.45 459.39,-45.63 463.97,-40.34\"/>\n",
       "</g>\n",
       "<!-- spend_zero_mean&#45;&gt;spend_zero_mean_unit_variance -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>spend_zero_mean&#45;&gt;spend_zero_mean_unit_variance</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M371.12,-72.41C381.11,-63.78 393.52,-53.05 404.59,-43.48\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"406.46,-45.63 411.74,-36.45 401.88,-40.34 406.46,-45.63\"/>\n",
       "</g>\n",
       "<!-- signups_location -->\n",
       "<g id=\"node10\" class=\"node\">\n",
       "<title>signups_location</title>\n",
       "<ellipse fill=\"none\" stroke=\"black\" stroke-dasharray=\"5,2\" cx=\"97.93\" cy=\"-306\" rx=\"97.93\" ry=\"18\"/>\n",
       "<text text-anchor=\"middle\" x=\"97.93\" y=\"-300.95\" font-family=\"Times,serif\" font-size=\"14.00\">Input: signups_location</text>\n",
       "</g>\n",
       "<!-- signups_location&#45;&gt;signups -->\n",
       "<g id=\"edge8\" class=\"edge\">\n",
       "<title>signups_location&#45;&gt;signups</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M97.93,-287.7C97.93,-280.24 97.93,-271.32 97.93,-262.97\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"101.43,-263.1 97.93,-253.1 94.43,-263.1 101.43,-263.1\"/>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x13d65f790>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Visualization requires the optional extra: `pip install \"sf-hamilton[visualization]\"`.\n",
    "dr.visualize_execution(output_columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3a383222",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-09-18T23:28:11.606616Z",
     "start_time": "2023-09-18T23:28:10.130104Z"
    },
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "ray.shutdown()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
